# -*- coding: utf-8 -*-
"""
Created on Tue Sep 20 13:42:34 2022

@author: obiri
"""

import mpctools as mpc
import casadi
import numpy as np
import matplotlib.pyplot as plt
import time
import sys
from scipy import linalg
from numpy import random

from MHE_model_normalised import reactor_tank, reactor_tank_uncert, measurement_xy_model, p_n, width,lb
from MHE_model_normalised import p_n as p_true
from MHE_model_normalised import pn as p_nominal
# Alias time.clock to time.time. time.clock was removed from the standard
# library in Python 3.8.
# NOTE(review): presumably one of the imported packages still calls
# time.clock() -- TODO confirm which one and drop this once it is updated.
time.clock = time.time



# Seed NumPy's global generator so repeated runs draw the same numbers.
random.seed(927)

# Plot switch (not consulted anywhere in the visible part of this script).
doPlots = True
# True  -> full-information estimation (window grows with time)
# False -> moving-horizon estimation with a window of at most Nt steps
fullInformation = False

# Time grid.
Nsim = 50    # number of simulated steps
Nt = 50      # estimation window length, in steps
Delta = 60   # time step
tplot = Delta * np.arange(Nsim + 1)

# Problem dimensions.
Nx = 16 + 5  # 16 entries plus 5 appended entries in the augmented state
Nu = 8       # inputs
Ny = 7       # outputs
Nw = Nx      # one process disturbance per augmented state
Nv = Ny      # one measurement-noise term per output

# Standard deviations used to weight the estimator cost terms.
sigma_p = 50
sigma_v = 0.0095 * np.array([0.05, 0.05, 0.05, 0.009, 0.009, 0.0095, 0.009])
sigma_w = 0.5 * np.array([
    0.05, 0.5*1e6, 0.08, 0.0005, 0.005, 0.0008, 0.003, 0.005,
    1e6, 0.015, 0.001, 0.003, 0.00095, 0.005, 0.003, 0.0001,
    0.05, 0.05, 0.0001, 0.05, 0.05,
])

# Diagonal covariance matrices (variance = sigma squared).
R = np.diag((np.ones(Nv) * sigma_v) ** 2)   # measurement noise
Q = np.diag((np.ones(Nw) * sigma_w) ** 2)   # process noise
P = np.diag((np.ones(Nx) * sigma_p) ** 2)   # prior
# P = np.eye((Nx))

#FOR 16 STATES
# Un-normalised values for the 16 leading states (not referenced again in the
# visible part of this script).
x0 = np.array([
    3.22456499e+10, 8.74344999e+10, 2.21584610e+03, 3.56206935e+00,
    6.72674220e+01, 5.03509229e+00, 6.80219613e+01,
    2.58024685e+09, 6.99637296e+09, 1.77308565e+03, 2.85031259e+00,
    5.38263469e+01, 4.02900269e+00, 5.44301175e+01,
    3.69613398e+01, 5.44301175e+01,
])

# Normalised starting point of the simulated plant: 16 state entries followed
# by the parameter entries computed from p_true with the mapping
# (value - 0.1*nominal) / (1.4*nominal).  Copied into xsim[0] further down.
x0_state_norm = 1.15 * np.array([
    0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.3,
    0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.3,
    0.25, 0.5,
])
x0_param_norm = (p_true - 0.1*p_nominal)/(1.4*p_nominal)
x0_aug_norm = np.concatenate([x0_state_norm, x0_param_norm])

# Initial guess handed to the estimator: the first 16 plant entries scaled by
# 1.2, and the parameter entries computed from p_nominal with the same mapping.
x_0_s = 1.2 * x0_aug_norm[:16]
x_0_p = (p_nominal - 0.1*p_nominal)/(1.4*p_nominal)
x_0 = np.concatenate([x_0_s, x_0_p])




#VARYING THE FLOWRATE TO INCREASE THE OSCILLATIONS IN THE PLOTS
# Nominal input vector and the five levels derived from it.
us = np.array([2.16100550e+01, 2.16155565e+01, 5.50143837e-03, 2.16100550e+01,
               2.19999850e+03, 1.30642193e+01, 4.06999944e+01, 2.16100550e+01])
us0 = 0.98*us
us1 = 0.99*us
us2 = 1.01*us
us3 = 1*us
us4 = 1*us
us_o = np.array([us0, us1, us2, us3, us4])

# Piecewise-constant input trajectory: the first four levels are held for ten
# steps each, the fifth level fills every remaining row from step 40 onwards.
u = np.zeros((Nsim, Nu))
n_inputs = us_o.shape[1]
for level, start in enumerate(range(0, 40, 10)):
    u[start:start + 10, :n_inputs] = us_o[level]
u[40:, :n_inputs] = us_o[4]
  
    
# Make a simulator.
# Plant: discrete-time simulator built from reactor_tank_uncert with
# arguments (x, u, w) and sample time Delta.
model_reactor_simulator = mpc.DiscreteSimulator(
    reactor_tank_uncert,
    Delta,
    [Nx, Nu, Nw],
    ["x", "u", "w"],
)

# Estimator model: reactor_tank turned into an explicit discrete-time CasADi
# function with RK4 (rk4=True, step Delta, M=1).
F = mpc.getCasadiFunc(
    reactor_tank,
    [Nx, Nu, Nw],
    ["x", "u", "w"],
    "F",
    rk4=True,
    Delta=Delta,
    M=1,
)
# F1 = mpc.getCasadiFunc(reactor_tank_uncert,[Nx,Nu,Nw],["x","u","w"],"F",rk4=True,Delta=Delta,M=1)

# Measurement function y = H(x).
H = mpc.getCasadiFunc(
    measurement_xy_model,
    [Nx],
    ["x"],
    "H",
)


# Define stage costs.
def lfunc(w, v):
    """Return the stage cost w' Q^-1 w + v' R^-1 v.

    Q and R are the module-level covariance matrices; they are inverted each
    time this function is called.
    """
    process_term = mpc.mtimes(w.T, linalg.inv(Q), w)
    measurement_term = mpc.mtimes(v.T, linalg.inv(R), v)
    return process_term + measurement_term


l = mpc.getCasadiFunc(lfunc, [Nw, Nv], ["w", "v"], "l")
def lxfunc(x):
    """Return the prior cost x' P^-1 x using the module-level covariance P."""
    prior_weight = linalg.inv(P)
    return mpc.mtimes(x.T, prior_weight, x)


lx = mpc.getCasadiFunc(lxfunc, [Nx], ["x"], "lx")

# Reference values for the 16 leading states, extended with p_n.  The
# augmented vector is what the un-scaling step at the end of the script
# multiplies by the normalisation width and lower bound.
xs = np.array([
    3.27367434e+10, 8.28831113e+10, 2.71741538e+03, 5.78604334e+00,
    6.37656737e+01, 4.83137928e+00, 6.23866477e+01,
    2.61960613e+09, 6.63233551e+09, 2.17448574e+03, 4.63001301e+00,
    5.10255242e+01, 3.86608729e+00, 4.99220220e+01,
    3.70349577e+01, 4.99220218e+01,
])
xs_aug = np.concatenate([xs, p_n])

# sigma_v2 = 1e-3*np.array([ 0.484975, 0.554888, 0.315405, 0.115617, 0.554891, 0.54214, 0.590303 ])
# sigma_wa = 1e-3*np.array([0.484999, 0.554913, 0.315424, 0.115631, 0.554916, 0.542165, 0.590329, 0.484975, 0.554888, 0.315405, 0.115617, 0.554891, 0.54214, 0.590303, 0.498012, 0.590303])
# # sigma_wb = 1e-7*np.array([  0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.05, 0.05, 0.001, 0.001, 0.001, 0.001 ])
# sigma_wb = 1e-3*np.array([  0.001, 0.001, 0.001, 0.001, 0.001 ])

# Standard deviations of the noise injected into the simulated plant:
#   sigma_v2 -> measurement noise, one entry per output
#   sigma_wa -> process noise on the 16 leading states
#   sigma_wb -> process noise on the 5 appended entries
sigma_v2 = 0.0095 * np.array([0.05, 0.05, 0.05, 0.009, 0.009, 0.0095, 0.009])
sigma_wa = 0.5 * np.array([
    5*1e6, 0.5*1e5, 0.9, 0.0005, 0.005, 0.0008, 0.003, 0.005,
    1e6, 0.015, 0.0004, 0.003, 0.00075, 0.005, 0.003, 0.0001,
])
sigma_wb = 1e-3 * np.array([0.001, 0.001, 0.0001, 0.001, 0.001])
sigma_w2 = np.concatenate([sigma_wa, sigma_wb])

# Each sequence is drawn right after re-seeding the global generator.
random.seed(3)
w = random.randn(Nsim, Nw) * sigma_w2   # process noise, passed to the simulator
random.seed(4)
v = random.randn(Nsim, Nv) * sigma_v2   # measurement noise, added to the outputs


# Storage for the simulated plant trajectory and its measurements.
usim = u
xsim = np.zeros((Nsim + 1, Nx))
yclean = np.zeros((Nsim, Ny))
ysim = np.zeros((Nsim, Ny))
xsim[0] = x0_aug_norm

# Step the plant forward: record the noise-free output of the current state,
# then advance the state with the input and process noise of that step.
for t in range(Nsim):
    current_state = xsim[t]
    yclean[t] = measurement_xy_model(current_state)
    xsim[t + 1] = model_reactor_simulator.sim(current_state, usim[t], w[t])

# Measurements seen by the estimator = clean outputs + measurement noise.
ysim[:] = yclean + v


# Now do estimation.
# Result arrays, all zero-initialised.
xhat_ = np.zeros((Nsim + 1, Nx))   # one-step-ahead predictions xhat(t+1 | t)
xhat = np.zeros((Nsim, Nx))        # filtered estimates xhat(t | t)
yhat = np.zeros((Nsim, Ny))        # outputs evaluated at the estimates
vhat = np.zeros((Nsim, Nv))        # estimated measurement noise
what = np.zeros((Nsim, Nw))        # estimated process noise

# The estimator starts from the initial guess x_0; no solver guess yet.
x0bar = x_0
xhat[0] = x0bar
guess = dict()

# Solver options (defined here; the solver call below does not receive them).
solveroptions = dict(linear_solver='mumps')
        
# Time the whole estimation run: start from minus the current time and add the
# end time after the loop, which leaves the elapsed seconds.
totaltime = -time.time()         
# At every step t a new estimation problem is built over the samples
# tmin..t and solved; the last point of its solution is stored as the
# estimate for time t. Row 0 of xhat was set before the loop; row 0 of
# yhat and vhat is never written and stays zero.
for t in range(1, Nsim):
    # Define sizes of everything.    
    N = {"x":Nx, "y":Ny, "u":Nu}
    if fullInformation:          
        # Window covers all samples since time 0.
        N["t"] = t
        tmin = 0
    else:
        # Window is capped at Nt steps and slides once t exceeds Nt.
        N["t"] = min(t,Nt)
        tmin = max(0,t - Nt)
    tmax = t+1        
    #lb = {"x":np.zeros((N["t"] + 1,Nx))} 
    
    #UPPER AND LOWER BOUND OF THE STATES:
    # First 16 entries are bounded to [0.1, 1.0], last 5 to [0.714, 0.720].
    # NOTE(review): these values do not depend on t, so they are rebuilt
    # identically on every iteration.
    ub_val_x = 1.0*np.ones(16)
    ub_val_p = 0.720*np.ones(5)

    lb_val_x = 0.1*np.ones(16)
    lb_val_p = 0.714*np.ones(5)

    ub_val = np.concatenate((ub_val_x, ub_val_p))
    lb_val = np.concatenate((lb_val_x, lb_val_p))


    
    
    # NOTE(review): this rebinds the module-level name `lb`, which was also
    # imported from MHE_model_normalised at the top of the file.
    lb = {'x': lb_val}
    ub = {'x': ub_val}
    

    # Build the problem from the inputs usim[tmin:t] and the measurements
    # ysim[tmin:t+1]; build and solve are timed separately.
    # NOTE(review): `lx` and `x0bar` are defined in this script but are not
    # passed to mpc.nmhe here -- presumably this means no prior/arrival cost
    # enters the problem; TODO confirm that this is intended.
    buildtime = -time.time()
    solver = mpc.nmhe(f=F, h=H, u=usim[tmin:tmax-1,:],  
    # solver = mpc.nmhe(f=F1, h=H, u=usim[tmin:tmax-1,:],  
                      y=ysim[tmin:tmax,:], l=l, N=N, 
                      verbosity=0,
                      lb=lb, ub=ub, guess=guess,Delta=Delta)
    buildtime += time.time()
    solvetime = -time.time()
    sol = mpc.callSolver(solver)
    solvetime += time.time()
    print ("%3d (%5.3g s build, %5.3g s solve): %s"
           % (t, buildtime, solvetime, sol["status"]))
    # Stop at the first step whose status is not "Solve_Succeeded". The rows
    # of xhat/yhat/vhat/what from this step onwards then keep their initial
    # zeros.
    if sol["status"] != "Solve_Succeeded":
        break
    xhat[t,:] = sol["x"][-1,...] # This is xhat( t  | t )
    yhat[t,:] = measurement_xy_model(xhat[t,:])    
    vhat[t,:] = sol["v"][-1,...]
    # NOTE(review): always true here because the loop starts at t = 1.
    if t > 0:
        what[t-1,:] = sol["w"][-1,...]
    
    # Apply model function to get xhat(t+1 | t )
    # (F is evaluated with a zero disturbance vector.)
    xhat_[t+1,:] = np.squeeze(F(xhat[t,:], usim[t,:], np.zeros((Nw,))))
    
    # Save stuff to use as a guess. Cycle the guess.
    # Only the keys among "x", "w", "v" that the solution contains are copied.
    guess = {}
    for k in set(["x","w","v"]).intersection(sol.keys()):
        guess[k] = sol[k].copy()
    
    # Do some different things if not using full information estimation.    
    # NOTE(review): with Nt = 50 and Nsim = 50 as set at the top of the file,
    # t never exceeds 49, so `t + 1 > Nt` is never true and this branch
    # (window shift + EKF update) does not run.
    if not fullInformation and t + 1 > Nt:
        for k in guess.keys():
            guess[k] = guess[k][1:,...] # Get rid of oldest measurement.


#============================================================================#            
        # Do EKF to update prior covariance, but don't take EKF state. Remove if arrival cost is not needed
        # The first two values returned by mpc.ekf overwrite P and x0bar.
        [P, x0bar, _, _] = mpc.ekf(F,H,x=sol["x"][0,...],
            u=usim[tmin,:],w=sol["w"][0,...],y=ysim[tmin,:],P=P,Q=Q,R=R)
        
        # Need to redefine arrival cost.
        # Rebuilds lxfunc/lx after P has been updated.
        def lxfunc(x):
            return mpc.mtimes(x.T,linalg.inv(P),x)
        lx = mpc.getCasadiFunc(lxfunc,[Nx],["x"],"lx")
#============================================================================# 
    # Add final guess state for new time point.
    # The last row of every guess array is repeated once, so each guess grows
    # by one row for the next problem.
    for k in guess.keys():
        guess[k] = np.concatenate((guess[k],guess[k][-1:,...]))

totaltime += time.time()
print ( "Simulation took %.5g s." % totaltime )


# Map the normalised trajectories back to physical units with
#     value = normalised * delta + minvalue
# where delta = width * xs_aug and minvalue = lb * xs_aug.
lb = 0.1          # lower bound of the normalised model: lb * xs
ub = 1.5          # upper bound of the normalised model: ub * xs
width = ub - lb

# Only the first 21 entries of xs_aug take part in the scaling.
n_scaled = 21
delta = width * xs_aug[:n_scaled]
minvalue = lb * xs_aug[:n_scaled]

# Un-scale every column at once; the plant trajectory is cut to its first
# Nsim rows so that it lines up with the estimates.
x_actual_hat = xhat * delta + minvalue
x_actual = xsim[:Nsim] * delta + minvalue

# Store both un-scaled trajectories on disk.
np.save("x_actual_C3.npy", x_actual)
np.save("x_hat_C3.npy", x_actual_hat)




##SPECIFYING UNSCALED STATE VALUES FOR ACTUAL AND PREDICTED STATES
# Un-scaled values of the 16 leading states for the estimate and for the
# plant. The row count follows Nsim instead of a hard-coded 50, so the metric
# below keeps covering every simulated step if Nsim is changed.
predicted_x_actual_ = x_actual_hat[:Nsim, :16]
x_actual_ = x_actual[:Nsim, :16]

##FUNCTION FOR  STATE RMSE CALCULATION USING UNSCALED VALUES
def calculate_nrmse_state_unscaled(x_act, x_pred):
    """Return the mean over rows of the row-wise relative RMSE.

    For every element the relative error (x_pred - x_act) / x_act is squared;
    the squares are averaged along axis 1, the square root is taken per row,
    and the per-row values are averaged into a single number.

    Parameters
    ----------
    x_act : 2-D array of reference values (used as the denominator).
    x_pred : 2-D array of predicted values with the same shape.
    """
    relative = (x_pred - x_act) / x_act
    per_row = np.sqrt(np.mean(np.square(relative), axis=1))
    return np.mean(per_row)

# Relative RMSE of the state estimates in physical units.
nrmse3 = calculate_nrmse_state_unscaled(x_act=x_actual_, x_pred=predicted_x_actual_)
print("STATE NRMSE_unscaled:", nrmse3)




#-----------------------------Considering all 18 parameters, (both estimated and not)----------------------------------------#
##--FUNCTION FOR PARAMETER RMSE CALCULATION USING SCALED VALUES--(all 18 parameters)##
def rmse_param_normalized_all(x_act, x_pred):
    """Return the mean over rows of the row-wise absolute RMSE.

    The element-wise difference x_pred - x_act is squared (no division by the
    reference), averaged along axis 1, square-rooted per row, and the per-row
    values are averaged into a single number.

    Parameters
    ----------
    x_act : 2-D array of reference values.
    x_pred : 2-D array of predicted values with the same shape.
    """
    difference = x_pred - x_act
    per_row = np.sqrt(np.mean(np.square(difference), axis=1))
    return np.mean(per_row)
# Build 18-column parameter matrices: the columns from index 16 on of the
# normalised plant/estimate trajectories, padded with 13 constant columns.
# The row count follows Nsim (it was hard-coded to 50) because sum_Perr3,
# which is derived from these matrices, is later plotted against tplot[:-1],
# an axis with Nsim points.
# NOTE(review): 0.65 is presumably the normalised true value of the 13
# parameters that are not estimated, and xhat[0, 16] their normalised nominal
# value -- TODO confirm.
xsim_fixed_pars = 0.65*np.ones([Nsim, 13])
xhat_fixed_pars = xhat[0, 16]*np.ones([Nsim, 13])
actual_pars = np.concatenate((xsim[:Nsim, 16:], xsim_fixed_pars), axis=1)
estimated_pars = np.concatenate((xhat[:Nsim, 16:], xhat_fixed_pars), axis=1)
rmse_value_p_normalized = rmse_param_normalized_all(actual_pars, estimated_pars)
print("RMSE_p_scaled_all_18:", rmse_value_p_normalized)




##FUNCTION FOR PARAMETER RMSE CALCULATION USING UNSCALED VALUES--(all 18 parameters)
def rmse_param_unscaled_all(x_act, x_pred):
    """Return the mean over rows of the row-wise relative RMSE.

    Each element contributes ((x_pred - x_act) / x_act) squared; the squares
    are averaged along axis 1, square-rooted per row, and the per-row values
    are averaged into a single number.

    Parameters
    ----------
    x_act : 2-D array of reference values (used as the denominator).
    x_pred : 2-D array of predicted values with the same shape.
    """
    relative = (x_pred - x_act) / x_act
    per_row = np.sqrt(np.mean(np.square(relative), axis=1))
    return np.mean(per_row)
# Relative RMSE of the un-scaled parameter estimates. The row count follows
# Nsim instead of a hard-coded 50 so that every simulated step is included.
# NOTE(review): only the columns from index 16 on are compared here, although
# the variable name and the printed label say "all_18" -- TODO confirm which
# is intended.
rmse_value_p_unscaled_all_18 = rmse_param_unscaled_all(x_actual[:Nsim, 16:], x_actual_hat[:Nsim, 16:])
print("RMSE_p_unscaled_all_18:", rmse_value_p_unscaled_all_18)





# def rmse_param(x_act, x_pred):
#     error = ((x_pred - x_act))**2 # Element-wise relative error
#     rmse = np.sqrt(np.mean(error, axis=1))  # RMSE for each simulation
#     nrmse = np.mean(rmse)  # Mean of RMSE across simulations
#     return nrmse
# predicted_values_p = x_actual_hat[:50,16:]
# actual_values_p = x_actual[:50,16:]

# # def rmse_param(x_act, x_pred):
# #     error = ((x_pred - x_act) / x_act )**2 # Element-wise relative error
# #     rmse = np.sqrt(np.mean(error, axis=1))  # RMSE for each simulation
# #     nrmse = np.mean(rmse)  # Mean of RMSE across simulations
# #     return nrmse




# Plots.
#########-------STATE ERROR TRAJECTORY------########
##The error is calculated based on the normalised x and xhat values
#The error is summed across columns to get the error across the states for 1 time step(sum_err)
#The trajectory is the total error for each time step, so it's a plot of 50 values
# State error trajectory: absolute error of the 16 leading normalised states,
# summed over the states at every time step.
err3 = np.abs(xsim[:Nsim, :16] - xhat[:Nsim, :16])
sum_err3 = err3.sum(axis=1)
# Open a dedicated figure: without it this curve and the parameter-error
# curve below were drawn onto the same implicit axes.
plt.figure()
fig_0 = plt.plot(tplot[:-1], sum_err3)
plt.ylabel('overall error')
plt.xlabel('Time')

np.save('sum_error3', sum_err3)

# True parameter values repeated for every simulated step (one row per step).
P_true = p_true*np.ones((Nsim, 5))
np.save('p_true', P_true)


#########-------PARAMETER ERROR TRAJECTORY------########
# Absolute error between the padded parameter matrices, summed over the
# parameters at every time step.
Perr3 = np.abs(actual_pars - estimated_pars)
sum_Perr3 = Perr3.sum(axis=1)
plt.figure()
fig_1 = plt.plot(tplot[:-1], sum_Perr3)
plt.ylabel('overall error')
plt.xlabel('Time')

np.save('sum_Perror3', sum_Perr3)

# # rmse_value_p = rmse_param(actual_values_p, predicted_values_p)
# rmse_value_p = rmse_param(xsim[:50, 16:], xhat[:50,16:])
# print("RMSE_parameters:", rmse_value_p)

# xsim_fixed_pars = 0.65*np.ones([50,13])
# xhat_fixed_pars = xhat[0,16]*np.ones([50,13])

# actual_pars = np.concatenate((xsim[:50,16:], xsim_fixed_pars),axis=1)
# estimated_pars = np.concatenate((xhat[:50,16:], xhat_fixed_pars), axis=1)
# rmse_value_p = rmse_param(actual_pars, estimated_pars)
# print("RMSE_parameters:", rmse_value_p)




# #----------------------STATES----------------------------------#
# Columns 0-7: plant value (solid blue) against estimate (dashed red), one
# panel per column on a 4x2 grid filled row by row.
# The figure size is set BEFORE the figure is created; assigning
# rcParams["figure.figsize"] afterwards only affects figures created later,
# so this figure previously kept the default size.
plt.rcParams["figure.figsize"] = [8.00, 7.00]
fig1, axs = plt.subplots(4, 2)

fig1_ylabels = [
    'Xv_1 (cell/L)', 'Xt1 (cell/L)',
    'GLC1 (mM)', 'GLN1 (mM)',
    'LAC1 (mM)', 'AMM1 (mM)',
    'mAb1 (mg/L)', 'Xv2 (cell/L)',
]
for column, (panel, ylabel) in enumerate(zip(axs.flat, fig1_ylabels)):
    panel.plot(tplot[:-1], x_actual[:, column], 'tab:blue', linewidth=3)
    panel.plot(tplot[:-1], x_actual_hat[:, column], '--r', linewidth=3)
    panel.set_ylabel(ylabel)

fig1.legend(['Actual', 'Estimate'])
# Label the time axis on the bottom row only.
plt.setp(axs[-1, :], xlabel='Time(min)')




# Columns 8-15: plant value (solid blue) against estimate (dashed red), one
# panel per column on a 4x2 grid filled row by row.
# The figure size is set before the figure is created so that this figure
# does not depend on an rcParams assignment made elsewhere in the script.
plt.rcParams["figure.figsize"] = [8.00, 7.00]
fig2, axs = plt.subplots(4, 2)

fig2_ylabels = [
    'Xt2 (cell/L)', 'GLC2 (mM)',
    'GLN2 (mM)', 'LAC2 (mM)',
    'AMM2 (mM)', 'mAb2 (mg/L)',
    'T (d C)', 'c_buffer (mg/L)',
]
for column, (panel, ylabel) in enumerate(zip(axs.flat, fig2_ylabels), start=8):
    panel.plot(tplot[:-1], x_actual[:, column], 'tab:blue', linewidth=3)
    panel.plot(tplot[:-1], x_actual_hat[:, column], '--r', linewidth=3)
    panel.set_ylabel(ylabel)

fig2.legend(['Actual', 'Estimate'])
# Label the time axis on the bottom row only.
plt.setp(axs[-1, :], xlabel='Time(min)')




# #-----------------------parameters------------------------------------#
# Columns 16-20 (the parameter entries): plant value (solid blue) against
# estimate (dashed red), one panel per column stacked vertically. The last
# seven samples are left out of every curve.
# The figure size is set before the figure is created so that this figure
# does not depend on an rcParams assignment made elsewhere in the script.
plt.rcParams["figure.figsize"] = [8.00, 7.00]
fig3, axs = plt.subplots(5, 1)

fig3_ylabels = [
    'm_mabx (mg/Cell/min)',
    'K_damm (mM)',
    'rho (g/L)',
    'mu_dmax',
    'K_gln (mM)',
]
for column, (panel, ylabel) in enumerate(zip(axs, fig3_ylabels), start=16):
    panel.plot(tplot[:-8], x_actual[:-7, column], 'tab:blue', linewidth=3)
    panel.plot(tplot[:-8], x_actual_hat[:-7, column], '--r', linewidth=3)
    panel.set_ylabel(ylabel)

fig3.legend(['Actual', 'Estimate'])
# The time label goes on the bottom panel (the axes plt.xlabel addressed
# implicitly before).
axs[-1].set_xlabel('Time(min)')































